import json
import re
import os
from collections import defaultdict, Counter
from math import log
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

# One bracketed tag 【...】 containing an action keyword. The character class
# keeps the match inside a single tag (it cannot cross 【, 】 or a newline),
# so keywords in running text between tags are not mistaken for actions.
_ACTION_TAG_RE = re.compile(
    r'【[^【】\n]*?(Decomposition|Reflection|Verification|However|Retry|Transition|Alternative)[^【】\n]*?】'
)


def extract_actions(cot):
    """Return the reasoning-action keywords tagged in a chain of thought.

    A tag is a single 【...】 span that contains one of the known keywords.
    If a tag contains several keywords, only the first one is reported.

    Args:
        cot: Chain-of-thought text.

    Returns:
        The matched keywords, lower-cased, in order of appearance.
    """
    return [action.lower() for action in _ACTION_TAG_RE.findall(cot)]

def calculate_task_lengths(dataset, task_name):
    """Group chain-of-thought word counts by each item's ``Length`` tag.

    Args:
        dataset: Iterable of dicts, each with a ``"Length"`` key and a
            ``"CoT"`` string.
        task_name: Not used by this function; kept for call-site
            compatibility.

    Returns:
        Dict mapping each length tag (in first-seen order) to a dict with:

        - ``"mean"``: mean word count.
        - ``"count"``: number of items.
        - ``"lengths"``: the list of whitespace-delimited word counts, in
          dataset order.
    """
    buckets = {}
    for record in dataset:
        buckets.setdefault(record["Length"], []).append(len(record['CoT'].split()))

    return {
        tag: {
            "mean": np.mean(counts) if counts else 0,
            "count": len(counts),
            "lengths": counts,
        }
        for tag, counts in buckets.items()
    }

def analyze_length_distribution(task_stats, task_name):
    """Plot and save histograms of CoT word counts for each length tag.

    Draws up to two side-by-side subplots: ``"Short"`` on the left and
    ``"Long"`` on the right. A tag that is missing from *task_stats*, or
    that has no recorded lengths, is skipped.

    Each subplot shows a histogram with a KDE overlay. It also has vertical
    lines at the mean (dashed) and the median (dotted).

    The figure is saved to ``{task_name}_length_distribution.png`` in the
    current working directory and displayed with ``plt.show()``. It is then
    closed so that repeated calls do not accumulate open figures.

    Args:
        task_stats: Mapping from length tag to a dict containing a
            ``"lengths"`` list, as produced by ``calculate_task_lengths``.
        task_name: Used in the subplot titles and in the output filename.
    """
    fig = plt.figure(figsize=(12, 6))
    try:
        for position, length_type in enumerate(["Short", "Long"], 1):
            lengths = task_stats.get(length_type, {}).get("lengths")
            if not lengths:
                continue

            # Only the mean and median are drawn. Shape statistics and
            # normality tests are intentionally not computed: nothing uses
            # them. Also, scipy.stats.normaltest rejects samples smaller
            # than 8 (raising ValueError in many SciPy versions), which
            # would abort the run for small groups.
            mean = np.mean(lengths)
            median = np.median(lengths)

            plt.subplot(1, 2, position)
            sns.histplot(lengths, kde=True, color='skyblue' if length_type == "Short" else 'salmon')
            plt.axvline(mean, color='r', linestyle='--', label=f'mean {mean:.1f}')
            plt.axvline(median, color='g', linestyle=':', label=f'median {median:.1f}')
            plt.title(f'{task_name} - {length_type} length distribution\n(n={len(lengths)})')
            plt.xlabel('word')
            plt.ylabel('count')
            plt.legend()

        plt.tight_layout()
        plt.savefig(f'{task_name}_length_distribution.png')
        plt.show()
    finally:
        # Release the figure even if plotting fails part-way.
        plt.close(fig)

def calculate_dynamic_length_range(task_name, length_type, task_stats):
    """Compute ``(min_length, max_length, base_length)`` word bounds for one length tag.

    Large sample (at least 30 observed lengths for *length_type*):

    - The bounds are the empirical 5th and 95th percentiles, truncated
      to ``int``.
    - The base is the median.

    Otherwise:

    - The base blends the observed mean with a per-task preset, weighted
      by ``alpha = min(1, count / 100)``.
    - The bounds are the base scaled down and up by a per-tag tolerance:
      25% for ``"Short"`` and 40% for ``"Long"``.

    Any *length_type* other than ``"Short"`` uses the ``"Long"`` preset and
    mean. The tolerance lookup, however, only knows ``"Short"`` and
    ``"Long"`` and raises ``KeyError`` for anything else.

    Args:
        task_name: Key into the preset table; unknown names use ``"default"``.
        length_type: ``"Short"`` or ``"Long"``.
        task_stats: Mapping as produced by ``calculate_task_lengths``.

    Returns:
        Tuple ``(min_length, max_length, base_length)``.
    """
    ideal_by_task = {
        "sentiment": {"Short": 70, "Long": 180},
        "emotion": {"Short": 90, "Long": 220},
        "sarcasm": {"Short": 110, "Long": 280},
        "humor": {"Short": 100, "Long": 260},
        "default": {"Short": 80, "Long": 200},
    }
    tolerance_by_type = {"Short": 0.25, "Long": 0.4}

    own_stats = task_stats.get(length_type, {})
    observed = own_stats.get("lengths", [])

    # Enough data: trust the empirical distribution directly.
    if observed and len(observed) >= 30:
        return (
            int(np.percentile(observed, 5)),
            int(np.percentile(observed, 95)),
            np.median(observed),
        )

    bucket = "Short" if length_type == "Short" else "Long"
    prior = ideal_by_task.get(task_name, ideal_by_task["default"])[bucket]
    observed_mean = task_stats.get(bucket, {}).get("mean", prior)

    # Shrink the observed mean toward the preset when few samples exist.
    alpha = min(1.0, own_stats.get("count", 0) / 100)
    base = alpha * observed_mean + (1 - alpha) * prior

    tol = tolerance_by_type[length_type]
    return int(base * (1 - tol)), int(base * (1 + tol)), base

def calculate_diversity(dataset):
    """Return the Shannon entropy (natural log) of the ``Label`` distribution.

    Args:
        dataset: Sized collection of dicts, each with a ``"Label"`` key.

    Returns:
        ``-sum(p * ln p)`` over the label frequencies; 0 for an empty
        dataset.
    """
    frequencies = Counter(row['Label'] for row in dataset).values()
    n = len(dataset)
    entropy_terms = ((c / n) * log(c / n) for c in frequencies)
    return -sum(entropy_terms)

def _stratum_key(item):
    """Return the ``Label-ReasoningType-Length`` key used to stratify *item*."""
    return f"{item['Label']}-{item['ReasoningType']}-{item['Length']}"


def _stratum_distribution(data):
    """Return each stratum's share of *data* in percent (``{}`` for empty input)."""
    total = len(data)
    counts = Counter(_stratum_key(item) for item in data)
    return {key: count / total * 100 for key, count in counts.items()}


def _stratified_sample(data, target_size, random_state=42):
    """Draw *target_size* items from *data*, stratifying as finely as possible.

    ``train_test_split`` raises ``ValueError`` in two cases:

    - a stratum has a single member;
    - either split is smaller than the number of strata.

    The fine-grained ``Label-ReasoningType-Length`` key hits these easily on
    small or skewed data. Instead of aborting the run, fall back to
    stratifying on ``Label`` alone, and finally to an unstratified draw.
    """
    key_funcs = (
        ("Label-ReasoningType-Length", _stratum_key),
        ("Label", lambda item: item['Label']),
    )
    for description, key_func in key_funcs:
        strata = [key_func(item) for item in data]
        try:
            sampled, _ = train_test_split(
                data,
                train_size=target_size,
                stratify=strata,
                random_state=random_state,
            )
        except ValueError as err:
            print(f"Stratified sampling on {description} failed ({err}); falling back.")
        else:
            return sampled

    sampled, _ = train_test_split(data, train_size=target_size, random_state=random_state)
    return sampled


def generate_rl_dataset(input_path, output_path, max_samples=5000):
    """Build an RL training set from a JSON list of annotated CoT records.

    Steps:

    1. Load *input_path* (a JSON list).
    2. Derive the task name from the file's stem.
    3. Plot the length distribution per ``Length`` tag.
    4. Compute min/max/base word bounds per tag.
    5. Sample about one eighth of the records (at least 1, at most
       *max_samples*), stratified by ``Label-ReasoningType-Length``,
       falling back to coarser stratification when that is infeasible.
    6. Write one JSON object per line to *output_path*.

    Each input record must provide ``"text"``, ``"CoT"``, ``"Label"`` and
    ``"Length"``. Records without ``"ReasoningType"`` get ``"Unknown"``
    added in place.

    Side effects: saves ``{task_name}_length_distribution.png``. Prints
    the largest stratum-share deviation introduced by sampling and the
    action-keyword counts.

    Args:
        input_path: Path to the source JSON file.
        output_path: Path of the JSONL file to write.
        max_samples: Upper bound on the number of sampled records.

    Returns:
        The list of generated records (the same ones written to
        *output_path*).
    """
    with open(input_path, 'r', encoding='utf-8') as f:
        full_data = json.load(f)

    task_name = os.path.splitext(os.path.basename(input_path))[0]

    task_stats = calculate_task_lengths(full_data, task_name)
    analyze_length_distribution(task_stats, task_name)

    length_ranges = {}
    for length_type in ["Short", "Long"]:
        if length_type in task_stats:
            min_len, max_len, base_len = calculate_dynamic_length_range(
                task_name, length_type, task_stats
            )
            length_ranges[length_type] = {
                "min_length": min_len,
                "max_length": max_len,
                "base_length": base_len,
            }

    # Keep roughly one eighth of the data: at least one item, at most max_samples.
    target_size = min(max(1, len(full_data) // 8), max_samples)

    for item in full_data:
        item.setdefault('ReasoningType', 'Unknown')

    if len(full_data) > target_size:
        sampled_data = _stratified_sample(full_data, target_size)
    else:
        sampled_data = full_data

    # Report how far sampling shifted any stratum's share (percentage points).
    orig_dist = _stratum_distribution(full_data)
    sampled_dist = _stratum_distribution(sampled_data)
    max_deviation = max(
        (abs(share - sampled_dist.get(key, 0)) for key, share in orig_dist.items()),
        default=0,
    )
    print(f"Max stratum share deviation after sampling: {max_deviation:.2f} pp")

    dataset = []
    action_counter = Counter()
    for item in sampled_data:
        actions = extract_actions(item['CoT'])
        action_counter.update(actions)

        length_type = item["Length"]
        # Tags without computed bounds fall back to 0 for every bound.
        range_data = length_ranges.get(length_type, {})

        dataset.append({
            "prompt": item['text'],
            "reference_answer": f"{item['CoT']}\nFinal Label: {item['Label']}",
            "reference_actions": actions,
            "reference_length": len(item['CoT'].split()),
            "min_length": range_data.get("min_length", 0),
            "max_length": range_data.get("max_length", 0),
            "base_length": range_data.get("base_length", 0),
            "label": item['Label'],
            "reasoning_type": item['ReasoningType'],
            "length_type": length_type,
            "task_type": task_name,
        })

    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in dataset)

    for action, count in action_counter.most_common():
        print(f"【{action.capitalize()}】: {count}")

    return dataset

if __name__ == "__main__":
    # Placeholder paths: point these at the real annotated JSON input and the
    # desired JSONL output. The task name is taken from the input file's stem
    # (e.g. "sentiment.json" -> "sentiment") and selects the length presets;
    # unrecognized names use the "default" presets.
    input_file = "xxx.json"  
    output_file = "xxx.jsonl"
    
    rl_dataset = generate_rl_dataset(input_file, output_file)
    
    if rl_dataset:
        # First generated record, presumably kept for inspection.
        sample = rl_dataset[0]